/*
 * Copyright (c) 2017, Salesforce.com, Inc.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * * Redistributions of source code must retain the above copyright notice, this
 *   list of conditions and the following disclaimer.
 *
 * * Redistributions in binary form must reproduce the above copyright notice,
 *   this list of conditions and the following disclaimer in the documentation
 *   and/or other materials provided with the distribution.
 *
 * * Neither the name of the copyright holder nor the names of its
 *   contributors may be used to endorse or promote products derived from
 *   this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

package com.salesforce.op

/**
  * Created by allwefantasy on 20/9/2018.
  */

import com.salesforce.op.evaluators.{EvaluationMetrics, OpEvaluatorBase}
import com.salesforce.op.features.OPFeature
import com.salesforce.op.readers.DataFrameFieldNames._
import com.salesforce.op.stages.{OPStage, OpPipelineStage, OpTransformer}
import com.salesforce.op.utils.spark.RichDataset._
import com.salesforce.op.utils.stages.FitStagesUtil
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{DataFrame, SparkSession}

import scala.reflect.ClassTag


/**
  * Workflow model is a container and executor for the sequence of transformations that have been fit to the data
  * to produce the desired output features
  *
  * @param uid            unique identifier for this workflow model
  * @param trainingParams params that were used during model training
  */
class WowOpWorkflowModel(override val uid: String = UID[WowOpWorkflowModel], override val trainingParams: OpParams) extends OpWorkflowModel(uid, trainingParams) {

  /**
    * Replace the result features of this model.
    *
    * Only `resultFeatures` is reassigned; other inherited state (e.g. `rawFeatures`, `stages`) is left as is.
    *
    * @param _resultFeatures new result features, must not be null
    */
  final def setResultFeatures(_resultFeatures: Array[OPFeature]): Unit = {
    require(_resultFeatures != null, "result features must not be null")
    resultFeatures = _resultFeatures
  }

  /**
    * Load up the data as specified by the data reader then transform that data using the transformers specified in
    * this workflow. We will always keep the key and result features in the returned dataframe, but there are options
    * to keep the other raw & intermediate features.
    *
    * This method optimizes scoring by applying bulks of [[OpTransformer]] stages together on each step.
    * The rest of the stages are applied sequentially (as [[org.apache.spark.ml.Pipeline]] does).
    *
    * @param path                     optional path to write out the scores to a file
    * @param keepRawFeatures          flag to enable keeping raw features in the output DataFrame as well
    * @param keepIntermediateFeatures flag to enable keeping intermediate features in the output DataFrame as well
    * @param persistEveryKStages      how often to break up catalyst by persisting the data
    *                                 (applies for non [[OpTransformer]] stages only),
    *                                 to turn off set to Int.MaxValue (not recommended)
    * @param persistScores            should persist the final scores dataframe
    * @return Dataframe that contains all the columns generated by the transformers in this workflow model as well as
    *         the key and result features, along with other features if the above flags are set to true.
    */
  override def score(
                      path: Option[String] = None,
                      keepRawFeatures: Boolean = WowOpWorkflowModel.KeepRawFeatures,
                      keepIntermediateFeatures: Boolean = WowOpWorkflowModel.KeepIntermediateFeatures,
                      persistEveryKStages: Int = WowOpWorkflowModel.PersistEveryKStages,
                      persistScores: Boolean = WowOpWorkflowModel.PersistScores
                    )(implicit spark: SparkSession): DataFrame = {
    val (scores, _) = scoreFn(
      keepRawFeatures = keepRawFeatures,
      keepIntermediateFeatures = keepIntermediateFeatures,
      persistEveryKStages = persistEveryKStages,
      persistScores = persistScores
    )(spark)(path)
    scores
  }

  /**
    * Load up the data as specified by the data reader then transform that data using the transformers specified in
    * this workflow. We will always keep the key and result features in the returned dataframe, but there are options
    * to keep the other raw & intermediate features.
    *
    * This method optimizes scoring by applying bulks of [[OpTransformer]] stages together on each step.
    * The rest of the stages are applied sequentially (as [[org.apache.spark.ml.Pipeline]] does).
    *
    * @param evaluator                evaluator to use for metrics generation, must not be null
    * @param path                     optional path to write out the scores to a file
    * @param keepRawFeatures          flag to enable keeping raw features in the output DataFrame as well
    * @param keepIntermediateFeatures flag to enable keeping intermediate features in the output DataFrame as well
    * @param persistEveryKStages      how often to break up catalyst by persisting the data
    *                                 (applies for non [[OpTransformer]] stages only),
    *                                 to turn off set to Int.MaxValue (not recommended)
    * @param persistScores            should persist the final scores dataframe
    * @param metricsPath              optional path to write out the metrics to a file
    * @return Dataframe that contains all the columns generated by the transformers in this workflow model as well as
    *         the key and result features, along with other features if the above flags are set to true.
    *         Also returns metrics computed with evaluator.
    * @throws IllegalArgumentException if evaluator is null
    */
  override def scoreAndEvaluate(
                                 evaluator: OpEvaluatorBase[_ <: EvaluationMetrics],
                                 path: Option[String] = None,
                                 keepRawFeatures: Boolean = WowOpWorkflowModel.KeepRawFeatures,
                                 keepIntermediateFeatures: Boolean = WowOpWorkflowModel.KeepIntermediateFeatures,
                                 persistEveryKStages: Int = WowOpWorkflowModel.PersistEveryKStages,
                                 persistScores: Boolean = WowOpWorkflowModel.PersistScores,
                                 metricsPath: Option[String] = None
                               )(implicit spark: SparkSession): (DataFrame, EvaluationMetrics) = {
    // Option(null) would silently become None and the metrics lookup below would fail obscurely
    require(evaluator != null, "evaluator must not be null")
    val (scores, metrics) = scoreFn(
      keepRawFeatures = keepRawFeatures,
      keepIntermediateFeatures = keepIntermediateFeatures,
      persistEveryKStages = persistEveryKStages,
      persistScores = persistScores,
      evaluator = Option(evaluator),
      metricsPath = metricsPath
    )(spark)(path)
    // saveScores produces metrics whenever an evaluator is supplied, so this branch should be unreachable
    scores -> metrics.getOrElse(throw new IllegalStateException("Evaluator did not produce any metrics"))
  }

  /**
    * Load up the data by the reader, transform it and then evaluate
    *
    * @param evaluator   OP Evaluator
    * @param metricsPath optional path to write out the metrics
    * @param scoresPath  optional path to write out the scores
    * @param spark       spark session
    * @return evaluation metrics
    */
  override def evaluate[T <: EvaluationMetrics : ClassTag](
                                                            evaluator: OpEvaluatorBase[T], metricsPath: Option[String] = None, scoresPath: Option[String] = None
                                                          )(implicit spark: SparkSession): T = {
    val (_, eval) = scoreAndEvaluate(evaluator = evaluator, metricsPath = metricsPath, path = scoresPath)
    eval.asInstanceOf[T]
  }

  /**
    * Build a scoring function: the transformations DAG is computed once up front, and each invocation
    * of the returned function reads the raw data, applies the DAG, then selects/evaluates/saves the scores.
    *
    * @param keepRawFeatures          flag to enable keeping raw features in the output DataFrame as well
    * @param keepIntermediateFeatures flag to enable keeping intermediate features in the output DataFrame as well
    * @param persistEveryKStages      how often to break up catalyst by persisting the data, must be >= 1
    * @param persistScores            should persist the final scores dataframe
    * @param evaluator                optional evaluator
    * @param metricsPath              optional path to write out the metrics to a file
    * @return function from an optional scores output path to the scores dataframe and optional metrics
    * @throws IllegalArgumentException if persistEveryKStages < 1
    */
  override private[op] def scoreFn(
                                    keepRawFeatures: Boolean = WowOpWorkflowModel.KeepRawFeatures,
                                    keepIntermediateFeatures: Boolean = WowOpWorkflowModel.KeepIntermediateFeatures,
                                    persistEveryKStages: Int = WowOpWorkflowModel.PersistEveryKStages,
                                    persistScores: Boolean = WowOpWorkflowModel.PersistScores,
                                    evaluator: Option[OpEvaluatorBase[_ <: EvaluationMetrics]] = None,
                                    metricsPath: Option[String] = None
                                  )(implicit spark: SparkSession): Option[String] => (DataFrame, Option[EvaluationMetrics]) = {
    require(persistEveryKStages >= 1, s"persistEveryKStages value of $persistEveryKStages is invalid must be >= 1")

    // TODO: replace 'stages' with 'stagesDag'. (is a breaking change for serialization, but would simplify scoreFn)
    // Pre-compute transformations dag
    val dag = FitStagesUtil.computeDAG(resultFeatures)

    (path: Option[String]) => {
      // Generate the dataframe with raw features
      val rawData: DataFrame = generateRawData()
      try {
        // Apply the transformations DAG on raw data
        val transformedData: DataFrame = applyTransformationsDAG(rawData, dag, persistEveryKStages)

        // Select, evaluate and save the scores
        saveScores(
          path = path,
          keepRawFeatures = keepRawFeatures,
          keepIntermediateFeatures = keepIntermediateFeatures,
          transformedData = transformedData,
          persistScores = persistScores,
          evaluator = evaluator,
          metricsPath = metricsPath
        )
      } finally {
        // Release raw data even when transforming, evaluating or saving fails
        // (a no-op if it was never persisted)
        rawData.unpersist()
      }
    }
  }


  /**
    * Function to remove unwanted columns from scored dataframe, evaluate and save results
    *
    * @param path                     optional path to write out the scores to a file (as avro)
    * @param keepRawFeatures          flag to enable keeping raw features in the output DataFrame as well
    * @param keepIntermediateFeatures flag to enable keeping intermediate features in the output DataFrame as well
    * @param transformedData          transformed & scored dataframe
    * @param persistScores            should persist the final scores dataframe
    * @param evaluator                optional evaluator, evaluated on the full transformed dataframe
    * @param metricsPath              optional path to write out the metrics (as json text) to a file
    * @return cleaned up score dataframe & metrics (defined if and only if an evaluator was given)
    */
  private def saveScores
  (
    path: Option[String],
    keepRawFeatures: Boolean,
    keepIntermediateFeatures: Boolean,
    transformedData: DataFrame,
    persistScores: Boolean,
    evaluator: Option[OpEvaluatorBase[_ <: EvaluationMetrics]],
    metricsPath: Option[String]
  )(implicit spark: SparkSession): (DataFrame, Option[EvaluationMetrics]) = {

    // Evaluate and save the metrics
    val metrics: Option[EvaluationMetrics] = evaluator.map { ev =>
      val res: EvaluationMetrics = ev.evaluateAll(transformedData)
      metricsPath.foreach(p => spark.sparkContext.parallelize(Seq(res.toJson()), 1).saveAsTextFile(p))
      res
    }

    // Pick which features to return (always force the key and result features to be included)
    val keyAndResults: Array[String] = resultFeatures.map(_.name) :+ KeyFieldName
    val featuresToKeep: Array[String] = (keepRawFeatures, keepIntermediateFeatures) match {
      case (true, true) => Array.empty[String] // keep everything (no `data.select` needed)
      case (true, false) => rawFeatures.map(_.name) ++ keyAndResults
      case (false, true) => stages.map(_.getOutputFeatureName) ++ keyAndResults
      case (false, false) => keyAndResults
    }
    val scores =
      if (featuresToKeep.isEmpty) transformedData // keep everything (no `data.select` needed)
      else {
        val keep = featuresToKeep.toSet
        // keep the order of the columns the same when selecting, so the data wont be reshuffled
        val columns = transformedData.columns.filter(keep.contains).map(column)
        transformedData.select(columns: _*)
      }
    if (log.isTraceEnabled) {
      log.trace("Scores dataframe schema:\n{}", scores.schema.treeString)
      // Dataset.explain prints to stdout; route the (extended) plans through the logger instead
      log.trace("Scores dataframe plans:\n{}", scores.queryExecution.toString)
    }

    // Persist the scores if needed
    if (persistScores) scores.persist()

    // Save the scores if a path was provided
    path.foreach(scores.saveAvro(_))

    scores -> metrics
  }

}

case object WowOpWorkflowModel {

  /** Default for `keepRawFeatures` in scoring: raw features are kept in the scores dataframe */
  val KeepRawFeatures: Boolean = true

  /** Default for `keepIntermediateFeatures` in scoring: intermediate features are dropped from the scores */
  val KeepIntermediateFeatures: Boolean = false

  /** Default for `persistEveryKStages` in scoring: how often (in stages) the data is persisted */
  val PersistEveryKStages: Int = 5

  /** Default for `persistScores` in scoring: the final scores dataframe is persisted */
  val PersistScores: Boolean = true

}

